import lowp
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerModel, base_architecture


@register_model("transformer_lowp")
class TransformerLowp(TransformerModel):
    """Transformer model whose forward passes run inside a ``lowp.Lowp`` context.

    The lowp mode and warning flags are read from ``self.args`` on every call
    (``precision``, ``warn_patched``, ``warn_not_patched``).
    NOTE(review): relies on ``TransformerModel.__init__`` storing ``args`` on
    the instance — confirm against the fairseq version in use.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to *parser*.

        Registers the base Transformer arguments first, then the lowp options.
        """
        super(TransformerLowp, TransformerLowp).add_args(parser)

        def _str2bool(value):
            # The warn options originally required a value and kept it as a
            # raw string, so "--warn-patched False" was truthy. Parse the
            # value as a real boolean instead; a bare flag maps to True via
            # ``const`` below.
            lowered = value.strip().lower()
            if lowered in ("1", "true", "yes", "y", "on"):
                return True
            if lowered in ("0", "false", "no", "n", "off"):
                return False
            # argparse turns ValueError from a ``type`` callable into a
            # regular "invalid value" usage error.
            raise ValueError(value)

        parser.add_argument('--precision', type=str, default='BF16',
                            help='precision of lowp. default=BF16')
        parser.add_argument('--warn-patched', type=_str2bool, nargs='?',
                            const=True, default=False,
                            help='warn on lowp patched functions')
        parser.add_argument('--warn-not-patched', type=_str2bool, nargs='?',
                            const=True, default=False,
                            help='warn on lowp non-patched functions')

    @classmethod
    def build_model(cls, args, task):
        """Build a model instance for *task*.

        Fills in the ``transformer_lowp`` architecture defaults on *args*
        (mutating it) before delegating to ``TransformerModel.build_model``.
        """
        # set any default arguments
        transformer_lowp(args)
        # Delegate with ``cls`` (not the hard-coded class) so subclasses of
        # TransformerLowp are constructed as the subclass.
        return super(TransformerLowp, cls).build_model(args, task)

    def _lowp_context(self):
        """Return a ``lowp.Lowp`` context manager configured from ``self.args``."""
        return lowp.Lowp(mode=self.args.precision,
                         warn_patched=self.args.warn_patched,
                         warn_not_patched=self.args.warn_not_patched)

    def forward(self, *kargs, **kwargs):
        """Run ``TransformerModel.forward`` inside the lowp context."""
        with self._lowp_context():
            return super(TransformerLowp, self).forward(*kargs, **kwargs)

    def forward_encoder(self, *kargs, **kwargs):
        """Run the base class ``forward_encoder`` inside the lowp context.

        NOTE(review): assumes the fairseq base class defines
        ``forward_encoder``; if it does not, this raises AttributeError.
        """
        with self._lowp_context():
            return super(TransformerLowp, self).forward_encoder(*kargs, **kwargs)

    def forward_decoder(self, *kargs, **kwargs):
        """Run the base class ``forward_decoder`` inside the lowp context."""
        with self._lowp_context():
            return super(TransformerLowp, self).forward_decoder(*kargs, **kwargs)


@register_model_architecture("transformer_lowp", "transformer_lowp")
def transformer_lowp(args):
    """Fill in default hyperparameters for the ``transformer_lowp`` architecture.

    Sets each lowp option only when it is missing from *args*, then applies
    the base Transformer defaults via ``base_architecture``. Mutates *args*
    in place and returns None.
    """
    # Mirror the defaults declared in TransformerLowp.add_args so a Namespace
    # that never went through that parser (presumably e.g. args restored from
    # an older checkpoint) still has every attribute the forward methods read.
    args.precision = getattr(args, "precision", "BF16")
    args.warn_patched = getattr(args, "warn_patched", False)
    args.warn_not_patched = getattr(args, "warn_not_patched", False)
    base_architecture(args)
